from datasets import load_dataset, load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, TrainerCallback, AutoConfig
import argparse
import torch
import os
from copy import deepcopy
import numpy as np
import multiprocessing
import sys
from rmjpo_trainer import DPOTrainer
from rmjpo_config import DPOConfig


class MemoryCleanCallback(TrainerCallback):
    """Trainer callback that empties the CUDA allocator cache after every step."""

    def on_step_end(self, args, state, control, **kwargs):
        """Release cached, currently unused CUDA memory when a training step ends.

        The hook arguments are accepted to match the callback signature but
        are not used here.
        """
        release_cuda_cache = torch.cuda.empty_cache
        release_cuda_cache()



def main() -> None:
    """Run pairwise RMJ-DPO fine-tuning from command-line arguments.

    Steps:
      1. Parse CLI arguments.
      2. Load the policy model in bfloat16 and create a frozen deep copy of
         it to serve as the reference model.
      3. Load the tokenizer, filling in a pad token and a chat template when
         the checkpoint does not provide them.
      4. Load the preference dataset from disk and render the ``prompt``,
         ``chosen`` and ``rejected`` fields as chat-template strings.
      5. Train with ``DPOTrainer``, then save the model and tokenizer.
      6. If the trainer reports ``output_reference_dispersion``, concatenate
         the collected tensors and save them with ``np.save``.
    """
    parser = argparse.ArgumentParser(description="Pairwise RMJ-DPO Training Script")
    # These two are needed immediately to load the model and the dataset, so
    # fail at argument parsing with a clear message instead of deep inside
    # the loading code.
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--loss_type", type=str)
    parser.add_argument("--per_device_train_batch_size", type=int)
    parser.add_argument("--gradient_accumulation_steps", type=int)
    parser.add_argument("--num_train_epochs", type=int)
    parser.add_argument("--beta", type=float)
    parser.add_argument("--max_length", type=int)
    parser.add_argument("--max_prompt_length", type=int)
    parser.add_argument("--max_completion_length", type=int)
    parser.add_argument("--save_steps", type=int)
    parser.add_argument("--logging_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--logging_dir", type=str)
    parser.add_argument("--dpo_finetuned_model_saved_dir", type=str)
    parser.add_argument("--output_reference_dispersion_local_dir", type=str)
    args = parser.parse_args()

    model_name_or_path = args.model_name_or_path

    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)
    # The generation cache is not needed for training; gradient checkpointing
    # is enabled in the trainer config below.
    model.config.use_cache = False

    # The reference model is a frozen snapshot of the initial policy.
    ref_model = deepcopy(model)
    ref_model.requires_grad_(False)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        # Fallback template used only when the checkpoint ships without one.
        tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"

    ds = load_from_disk(args.data_path)

    def process(row):
        """Render one preference pair as prompt / chosen / rejected strings.

        The prompt is built from every message of ``row["chosen"]`` except
        the last one; the chosen and rejected completions are the last
        message of their respective conversations.
        """
        row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
        row["chosen"] = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)

        # The template may prepend a BOS token to each rendered completion;
        # strip it so the completion does not start with one. Tokenizers
        # without a BOS token (bos_token is None) are left untouched.
        bos_token = tokenizer.bos_token
        if bos_token:
            for key in ("chosen", "rejected"):
                if row[key].startswith(bos_token):
                    row[key] = row[key][len(bos_token):]
        return row

    # NOTE: the original call passed ``num_proc`` twice (100 and cpu_count()),
    # which is a SyntaxError. The CPU-count value is kept.
    processed_dataset = ds.map(
        process,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=True,
    )

    dpo_args = DPOConfig(
        loss_type=args.loss_type,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        output_reference_dispersion=True,
        beta=args.beta,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        bf16=True,
        save_strategy="steps",
        save_steps=args.save_steps,
        output_dir=args.output_dir,
        gradient_checkpointing=True,
        logging_steps=args.logging_steps,
        logging_dir=args.logging_dir,
        lr_scheduler_type="cosine",
        warmup_ratio=0.2,
    )

    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=dpo_args,
        train_dataset=processed_dataset,
        processing_class=tokenizer,
    )
    trainer.add_callback(MemoryCleanCallback())

    trainer.train()

    trainer.save_model(args.dpo_finetuned_model_saved_dir)
    tokenizer.save_pretrained(args.dpo_finetuned_model_saved_dir)

    if trainer.output_reference_dispersion:
        # Detach and move to host memory before converting: ``.numpy()`` only
        # works on CPU tensors that do not require grad. Both calls are
        # no-ops for tensors that are already detached and on the CPU.
        all_reference_dispersion = (
            torch.cat(trainer.all_reference_dispersion).detach().float().cpu().numpy()
        )
        np.save(args.output_reference_dispersion_local_dir, all_reference_dispersion)

# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



